import torch
import torchvision.transforms as transforms
import torchvision.datasets as dset
import random

def get_data_ld(dataset, bsize):
  """Wrap `dataset` in a shuffling DataLoader.

  Args:
    dataset: dataset object handed straight to torch's DataLoader.
    bsize: number of samples per batch.

  Returns:
    A DataLoader with shuffle=True, two worker processes and
    pin_memory=True.
  """
  loader_kwargs = dict(batch_size=bsize, shuffle=True, num_workers=2,
                       pin_memory=True)
  return torch.utils.data.DataLoader(dataset, **loader_kwargs)


def copy_param(dest, model):
  """Copy `model`'s parameter values into the tensors of `dest`, in place.

  `dest` is paired positionally with `model.parameters()`; pairing stops at
  the shorter of the two (zip semantics). Nothing is returned.
  """
  with torch.no_grad():
    src_params = model.parameters()
    for target, source in zip(dest, src_params):
      target.copy_(source.data)


def vec_sub(g1, g2):
  """Subtract `g2` from `g1` element-wise, mutating `g1`'s tensors in place.

  Args:
    g1: sequence of tensors to update.
    g2: sequence paired positionally with `g1`; extra entries in the longer
      sequence are ignored (zip semantics).
  """
  pairs = zip(g1, g2)
  with torch.no_grad():
    for lhs, rhs in pairs:
      lhs -= rhs  # Tensor.__isub__ == in-place sub_


def vec_mul(g1, c):
  """Scale every tensor in `g1` by `c`, in place.

  Args:
    g1: sequence of tensors to update.
    c: multiplier applied to each tensor via in-place multiplication.
  """
  with torch.no_grad():
    for tensor in g1:
      tensor *= c  # Tensor.__imul__ == in-place mul_


def vec_add(g1, g2):
  """Add `g2` to `g1` element-wise, mutating `g1`'s tensors in place.

  Args:
    g1: sequence of tensors to update.
    g2: sequence paired positionally with `g1`; extra entries in the longer
      sequence are ignored (zip semantics).
  """
  pairs = zip(g1, g2)
  with torch.no_grad():
    for lhs, rhs in pairs:
      lhs += rhs  # Tensor.__iadd__ == in-place add_


def vec_round(g):
  """Round every tensor in `g` toward zero, in place.

  Each element x becomes sign(x) * floor(|x|), i.e. truncation
  (1.7 -> 1, -1.7 -> -1). The tensors in `g` are mutated; nothing is
  returned.
  """
  with torch.no_grad():
    for p in g:
      # Write the result back through copy_: assigning to `p` would only
      # rebind the loop variable and leave the caller's tensors unchanged.
      p.copy_(torch.sign(p) * torch.floor(torch.abs(p)))


def get_dic_str(dic):
  """Render a mapping as concatenated `key:value, ` pairs in iteration order.

  Values are converted with `str()`. The result keeps a trailing `', '`;
  an empty mapping yields ''.
  """
  # str.join is linear; repeated `+=` on a str can degrade to quadratic.
  return ''.join(f'{key}:{str(dic[key])}, ' for key in dic)


def vec_norm2(g):
  """Return the squared L2 norm of a collection of tensors as a Python float.

  Computes the sum of ||p||_2 ** 2 over every tensor `p` in `g`; an empty
  `g` gives 0.0.
  """
  with torch.no_grad():
    acc = 0
    for tensor in g:
      acc = acc + torch.norm(tensor) ** 2.
    return float(acc)


def print_log(log_info_dic, fname):
  """Print a dict as one `key:value, ` line and append that line to a file.

  Args:
    log_info_dic: mapping rendered with get_dic_str.
    fname: path of the log file; opened in 'a+' mode, so it is created if
      missing and appended to otherwise.
  """
  log_str = get_dic_str(log_info_dic)
  print(log_str)
  # The context manager flushes and closes the handle immediately instead of
  # leaving it to garbage collection.
  with open(fname, 'a+') as log_file:
    log_file.write(log_str + '\n')


def zero_grad(model, loss_fn, x0, y0):
  """Run one forward/backward pass on (x0, y0) on CUDA, then clear the grads.

  `x0` and `y0` are moved with `.cuda()`. The loss is `loss_fn(model(x0), y0)`
  and is backpropagated, after which `model.zero_grad()` is called. Nothing
  is returned.

  NOTE(review): presumably meant to materialize the parameters' `.grad`
  buffers up front. On PyTorch versions where `zero_grad` defaults to
  `set_to_none=True`, the grads end up None rather than zero-filled. Confirm
  the intent.
  """
  predictions = model(x0.cuda())
  loss = loss_fn(predictions, y0.cuda())
  loss.backward()
  model.zero_grad()


def accuracy(data_ld, model, device='cuda'):
  """Return the top-1 accuracy of `model` over the batches in `data_ld`.

  Args:
    data_ld: iterable of `(x, y)` batches (e.g. a DataLoader), where `y`
      holds the target class indices.
    model: classifier; the prediction is the argmax over dim 1 of its output.
    device: device each batch is moved to before the forward pass. Defaults
      to 'cuda', matching the previously hard-coded `.cuda()` calls.

  Returns:
    Fraction (float) of samples whose prediction equals the label.

  Raises:
    ZeroDivisionError: if `data_ld` yields no samples.

  The model runs in eval mode, and its previous train/eval mode is restored
  afterwards. Previously it was always left in train mode.
  """
  was_training = model.training
  model.eval()
  correct, seen = 0, 0
  try:
    with torch.no_grad():
      for x, y in data_ld:
        x, y = x.to(device), y.to(device)
        y_pred = torch.argmax(model(x), dim=1)
        correct += int(y_pred.eq(y).sum().item())
        seen += int(y.size(0))
  finally:
    # Restore the caller's mode even if the forward pass raises.
    model.train(was_training)
  return float(correct) / float(seen)


def number_of_parameter(model):
  """Return the total count of scalar elements across all of `model`'s parameters."""
  return sum(param.numel() for param in model.parameters())


def clone_param(model):
  """Return independent copies of `model`'s parameters as a list.

  The list follows `model.parameters()` order. The copies do not share
  storage with the model and do not require grad.
  """
  snapshot = []
  with torch.no_grad():
    for param in model.parameters():
      snapshot.append(param.data.clone())
  return snapshot

def random_label(dataset, p=0, num_classes=10):
  """Replace a random fraction of `dataset.targets` with uniform random labels.

  Args:
    dataset: object supporting `len()` whose `targets` attribute is
      indexable and item-assignable; it is modified in place.
    p: fraction of samples to relabel. `int(p * len(dataset))` distinct
      indices are drawn without replacement.
    num_classes: new labels are drawn uniformly from 0..num_classes-1.
      Defaults to 10, the previously hard-coded value.

  A relabelled sample may receive its original label again.
  """
  n = len(dataset)
  # random.sample accepts a range directly, so there is no need to build a
  # list. It draws the same indices for the same RNG state.
  for i in random.sample(range(n), int(p * n)):
    dataset.targets[i] = random.randint(0, num_classes - 1)